import os
import sys
import argparse

sys.path.append(
    os.path.realpath(os.path.join(os.path.dirname(__file__), '../..')))

from baselines.r2plus1d.dataset import FramesDataset
from baselines.r2plus1d.model import R2plus1D


def parse_args(argv=None):
    """Build the command-line parser and parse the training options.

    Options are organised in three groups: ``data`` (dataset files and
    root directories), ``model`` (base model selection) and ``train``
    (epochs, learning rate, batch size and output directory). Every
    option has a default value.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, which is what
            the script entry point relies on.

    Returns:
        argparse.Namespace: The parsed options, with dashes in option
        names converted to underscores (e.g. ``args.train_dataset``).
    """
    parser = argparse.ArgumentParser(
        description='Train R2plus1D baseline model.')

    data_group = parser.add_argument_group('data')
    data_group.add_argument(
        '--train-dataset', '-tr', type=str,
        default='data/xiaoduHi_train_v2.pkl',
        help='Path to training dataset which is generated by '
        'scripts/prepare_data.py.')
    data_group.add_argument(
        '--test-dataset', '-te', type=str,
        default='data/xiaoduHi_test_v2.pkl',
        help='Path to test dataset which is generated by '
        'scripts/prepare_data.py.')
    data_group.add_argument(
        '--full-neg-train', type=str,
        default='data/full_neg_train_valid_i0.80_s0.10.txt',
        help='Path to text file recording paths of full '
        'negative frames for training.')
    data_group.add_argument(
        '--full-neg-test', type=str,
        default='data/full_neg_test_valid_i0.80_s0.10.txt',
        help='Path to text file recording paths of full '
        'negative frames for testing.')
    data_group.add_argument(
        '--wae-lst-pkl', type=str, default='data/raw_wae/wae_lst.pkl',
        help='Path to wae-list pickle file.')
    data_group.add_argument(
        '--root', type=str, default='data/xiaodu_clips_v2',
        help='Root directory of positive clips.')
    data_group.add_argument(
        '--neg-root', type=str, default='data/full_neg_data',
        help='Root directory of negative clips.')
    data_group.add_argument(
        '--group-by', type=str, default='WAE_id',
        help='The way to group the classes, e.g. Scenario, WAE_id.')

    model_group = parser.add_argument_group('model')
    model_group.add_argument(
        '--base-model', type=str, default='ig65m',
        help='The R(2+1)D model is based on either ig65m or kinetics.')

    train_group = parser.add_argument_group('train')
    train_group.add_argument(
        '--epochs', type=int, default=10,
        help='The number of epochs to train the model.')
    train_group.add_argument(
        '--lr', type=float, default=0.0001, help='Learning rate.')
    train_group.add_argument(
        '--bs', type=int, default=8, help='Batch size.')
    train_group.add_argument(
        '--save', type=str, default='save',
        help='Directory to save parameters and log files.')

    # argv=None makes argparse fall back to sys.argv[1:], preserving the
    # behaviour of the original zero-argument call.
    return parser.parse_args(argv)


def main():
    """Parse the CLI options, build the dataset and train the model.

    Steps: parse arguments, construct a ``FramesDataset`` from the data
    options, print a short summary of the train/test splits, then create
    an ``R2plus1D`` learner and call its ``fit`` method with
    ``save_model=True``.
    """
    # Every option has a default, so a bare invocation would otherwise go
    # straight to training. Show the usage message instead (argparse exits
    # after printing help).
    if len(sys.argv) == 1:
        sys.argv.append('-h')
    args = parse_args()

    data = FramesDataset(
        args.train_dataset,
        args.test_dataset,
        args.full_neg_train,
        args.full_neg_test,
        group_by=args.group_by,
        root=args.root,
        neg_root=args.neg_root,
        wae_lst_pkl=args.wae_lst_pkl,
        batch_size=args.bs)

    # Adjacent f-strings are concatenated at compile time. A backslash
    # continuation inside a single literal would instead embed the next
    # line's indentation in the printed text.
    print(
        f"Training dataset: {len(data.train_ds)} | "
        f"Training DataLoader: {data.train_dl}\n"
        f"Testing dataset: {len(data.test_ds)} | "
        f"Testing DataLoader: {data.test_dl}")

    learner = R2plus1D(data, base_model=args.base_model)
    learner.fit(
        lr=args.lr,
        epochs=args.epochs,
        model_dir=args.save,
        save_model=True)


if __name__ == '__main__':
    main()
